'''
* This is the project for Brtc LlmOps Platform
* @Author Leon-liao <liaosiliang@alltman.com>
* @Description //TODO 
* @File: dataset_service.py
* @Time: 2025/9/25
* @All Rights Reserved By Brtc
'''
import logging
from dataclasses import dataclass
from uuid import UUID
from injector import inject
from sqlalchemy.sql.expression import desc
from internal.model import Dataset, Segment, DatasetQuery, AppDatasetJoin, Account
from internal.schema.dataset_schema import(CreateDataReq,
                                           UpdateDatasetReq,
                                           GetDatasetWithPageReq,
                                           HitReq)

from internal.entity.dataset_entity import DEFAULT_DATASET_DESCRIPTION_FORMATER
from internal.exception.exception import ValidateErrorException, NotFoundException, FailException
from internal.lib.helper import datetime_to_timestamp
from internal.task.dataset_task import delete_dataset
from pkg.paginator.paginator import Paginator
from pkg.sqlachemy.sqlalchemy import SQLAlchemy
from .base_service import BaseService
from .retrieval_service import RetrievalService



@inject
@dataclass
class DatasetService(BaseService):
    """Dataset (knowledge base) service."""
    # Database handle: the methods of this class query through ``db.session``
    # and wrap bulk writes in ``db.auto_commit()``.
    db:SQLAlchemy
    # Retrieval service called by ``hit`` (``search_in_datasets``).
    # NOTE(review): "servcie" is a misspelling of "service"; the name is left
    # untouched because it is a dataclass field referenced by ``hit``.
    retrieval_servcie:RetrievalService

    def create_dataset(self, req: CreateDataReq, account: Account) -> Dataset:
        """Create a dataset (knowledge base) record owned by ``account``.

        Args:
            req: Creation request; ``name``, ``icon`` and ``description`` are
                read through each field's ``.data`` attribute.
            account: Account that will own the new dataset.

        Returns:
            Whatever ``self.create`` returns for the new ``Dataset`` row.

        Raises:
            ValidateErrorException: If this account already has a dataset
                with the requested name.
        """
        # Dataset names are unique per account: look for an existing one first.
        duplicate = (
            self.db.session.query(Dataset)
            .filter_by(account_id=account.id, name=req.name.data)
            .one_or_none()
        )
        if duplicate:
            raise ValidateErrorException(f"该知识库{req.name.data}已经存在！！")

        # A missing or blank description is replaced by the templated default.
        # The request field itself is overwritten so the stored value and the
        # request stay in sync.
        description = req.description.data
        if description is None or description.strip() == "":
            req.description.data = DEFAULT_DATASET_DESCRIPTION_FORMATER.format(name=req.name.data)

        # Persist the record and hand it back to the caller.
        record_fields = {
            "account_id": account.id,
            "name": req.name.data,
            "icon": req.icon.data,
            "description": req.description.data,
        }
        return self.create(Dataset, **record_fields)


    def get_dataset(self, dataset_id: UUID, account: Account) -> Dataset:
        """Fetch a single dataset by id, restricted to the given account.

        Args:
            dataset_id: Primary key of the dataset to load.
            account: Account that must own the dataset.

        Returns:
            The matching ``Dataset``.

        Raises:
            ValidateErrorException: If no dataset has this id, or the dataset
                belongs to a different account (both cases are reported the
                same way).
        """
        dataset = self.get(Dataset, dataset_id)

        # Only hand the record back when it exists and is owned by the caller.
        if dataset is not None and dataset.account_id == account.id:
            return dataset

        raise ValidateErrorException("该知识库不存在！！")


    def update_dataset(self, dataset_id: UUID, req: UpdateDatasetReq, account: Account) -> Dataset:
        """Update the name, icon and description of an existing dataset.

        Args:
            dataset_id: Primary key of the dataset to update.
            req: Update request; ``name``, ``icon`` and ``description`` are
                read through each field's ``.data`` attribute.
            account: Account that must own the dataset.

        Returns:
            The ``Dataset`` instance that was passed to ``self.update``.

        Raises:
            ValidateErrorException: If the dataset does not exist, belongs to
                another account, or the new name is already used by a
                different dataset of the same account.
        """
        # Load the target and make sure the caller owns it.
        dataset = self.get(Dataset, dataset_id)
        if dataset is None or dataset.account_id != account.id:
            raise ValidateErrorException("该知识库不存在！！")

        # The new name must not collide with any *other* dataset of this account.
        name_conflict = (
            self.db.session.query(Dataset)
            .filter(
                Dataset.account_id == account.id,
                Dataset.name == req.name.data,
                Dataset.id != dataset_id,
            )
            .one_or_none()
        )
        if name_conflict:
            raise ValidateErrorException(f"该知识库名称{req.name.data}已经存在请修改！")

        # A missing or blank description falls back to the templated default;
        # the request field is overwritten in place.
        description = req.description.data
        if description is None or description.strip() == "":
            req.description.data = DEFAULT_DATASET_DESCRIPTION_FORMATER.format(name=req.name.data)

        # Apply the new values to the loaded record.
        changes = {
            "name": req.name.data,
            "icon": req.icon.data,
            "description": req.description.data,
        }
        self.update(dataset, **changes)
        return dataset


    def get_dataset_with_page(self, req: GetDatasetWithPageReq, account: Account) -> tuple[list[Dataset], Paginator]:
        """Return one page of the account's datasets, newest first.

        Args:
            req: Paging request; it is handed to ``Paginator`` as-is, and its
                ``search_word`` field (if non-empty) filters by dataset name.
            account: Account whose datasets are listed.

        Returns:
            A ``(datasets, paginator)`` tuple: the result of
            ``paginator.paginate`` and the paginator that produced it.
        """
        paginator = Paginator(db=self.db, req=req)

        # Always scope to the caller's account; optionally add a
        # case-insensitive substring match on the dataset name.
        conditions = [Dataset.account_id == account.id]
        keyword = req.search_word.data
        if keyword:
            conditions.append(Dataset.name.ilike(f"%{req.search_word.data}%"))

        # Order by creation time descending, then let the paginator run it.
        query = (
            self.db.session.query(Dataset)
            .filter(*conditions)
            .order_by(desc("created_at"))
        )
        page_items = paginator.paginate(query)
        return page_items, paginator


    def hit(self, dataset_id: UUID, req: HitReq, account: Account) -> list[dict]:
        """Run a retrieval ("hit") test against one dataset.

        The retrieval service is queried with the request payload, the
        segments it points to are loaded from the database, and one result
        dict per retrieved document is returned in retrieval order.

        Args:
            dataset_id: Primary key of the dataset to search in.
            req: Hit-test request; ``req.data`` is expanded as keyword
                arguments to ``search_in_datasets``.
            account: Account that must own the dataset.

        Returns:
            A list of dicts describing each hit segment (segment fields, a
            nested ``document`` summary and the retrieval ``score``).
            Retrieved documents whose segment no longer exists in the
            database are skipped. Empty when nothing was retrieved.

        Raises:
            NotFoundException: If the dataset does not exist or belongs to
                another account.
        """
        # 1. Make sure the dataset exists and is owned by the caller.
        dataset = self.get(Dataset, dataset_id)
        if dataset is None or dataset.account_id != account.id:
            raise NotFoundException("该知识库不存在")

        # 2. Ask the retrieval service for matching documents.
        lc_documents = self.retrieval_servcie.search_in_datasets(
            dataset_ids=[dataset_id],
            **req.data,
        )
        # Nothing retrieved: skip the pointless ``IN ()`` database round trip.
        if not lc_documents:
            return []

        # Compute each document's segment id once (retrieval order preserved)
        # and index the documents by it so the score can be looked up later.
        segment_ids = [str(lc_document.metadata["segment_id"]) for lc_document in lc_documents]
        lc_document_dict = dict(zip(segment_ids, lc_documents))

        # 3. Load the referenced segments in a single query.
        segments = self.db.session.query(Segment).filter(
            Segment.id.in_(segment_ids)
        ).all()
        segment_dict = {str(segment.id): segment for segment in segments}

        # 4. Walk the ids in retrieval order so the output keeps the ranking
        #    produced by the retrieval service, and assemble the payload.
        hit_result = []
        for segment_id in segment_ids:
            segment = segment_dict.get(segment_id)
            if segment is None:
                # The retrieved document points at a segment missing from the DB.
                continue

            document = segment.document
            upload_file = document.upload_file
            hit_result.append({
                "id": segment.id,
                "document": {
                    "id": document.id,
                    "name": document.name,
                    "extension": upload_file.extension,
                    "mime_type": upload_file.mime_type,
                },
                "dataset_id": segment.dataset_id,
                "score": lc_document_dict[segment_id].metadata["score"],
                "position": segment.position,
                "content": segment.content,
                "keywords": segment.keywords,
                "character_count": segment.character_count,
                "token_count": segment.token_count,
                "hit_count": segment.hit_count,
                "enabled": segment.enabled,
                "disabled_at": datetime_to_timestamp(segment.disabled_at),
                "status": segment.status,
                "error": segment.error,
                "updated_at": datetime_to_timestamp(segment.updated_at),
                "created_at": datetime_to_timestamp(segment.created_at),
            })
        return hit_result


    def get_dataset_queries(self, dataset_id: UUID, account: Account) -> list[DatasetQuery]:
        """Return the 10 most recent query records of a dataset.

        Args:
            dataset_id: Primary key of the dataset.
            account: Account that must own the dataset.

        Returns:
            Up to 10 ``DatasetQuery`` rows for this dataset, ordered by
            ``created_at`` descending.

        Raises:
            NotFoundException: If the dataset does not exist or belongs to
                another account.
        """
        # Ownership check first: unknown and foreign datasets look the same.
        dataset = self.get(Dataset, dataset_id)
        if dataset is None or dataset.account_id != account.id:
            raise NotFoundException("该知识库不存在")

        # Newest first, capped at ten rows.
        recent_queries = (
            self.db.session.query(DatasetQuery)
            .filter(DatasetQuery.dataset_id == dataset_id)
            .order_by(desc("created_at"))
            .limit(10)
        )
        return recent_queries.all()


    def delete_dataset(self, dataset_id: UUID, account: Account):
        """Delete a dataset and schedule the cleanup of its dependent data.

        The dataset row and its app/dataset association rows are removed
        here; the rest is delegated to the ``delete_dataset`` task imported
        from ``internal.task.dataset_task``, which is expected to remove the
        dataset's documents, segments, keywords and vector-store data
        (the task body is not visible from this file).

        Args:
            dataset_id: Primary key of the dataset to delete.
            account: Account that must own the dataset.

        Raises:
            NotFoundException: If the dataset does not exist or belongs to
                another account.
            FailException: If any deletion step or the task dispatch fails;
                the original error is attached as ``__cause__``.
        """
        # 1. Make sure the dataset exists and is owned by the caller.
        dataset = self.get(Dataset, dataset_id)
        if dataset is None or dataset.account_id != account.id:
            raise NotFoundException("该知识库不存在")
        try:
            # 2. Remove the dataset row, then its app/dataset association rows.
            # NOTE(review): these look like two separate write steps rather
            # than one transaction - confirm against ``BaseService.delete``.
            self.delete(dataset)
            with self.db.auto_commit():
                self.db.session.query(AppDatasetJoin).filter(
                    AppDatasetJoin.dataset_id == dataset_id,
                ).delete()
            # 3. Hand the remaining cleanup to the background task. The bare
            #    name ``delete_dataset`` resolves to the module-level task
            #    import, not to this method.
            delete_dataset.delay(dataset_id)
        except Exception as e:
            logging.exception(f"删除知识库失败, dataset_id: {dataset_id}, 错误信息: {str(e)}")
            # Chain explicitly so the root cause is recorded on the raised error.
            raise FailException("删除知识库失败，请稍后重试") from e














